{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71a329f0",
   "metadata": {},
   "source": [
    "# LLAMA 7B Chat model demo deployment guide for Neuron instances\n",
    "In this tutorial, you will use LMI container from DLC to SageMaker and run inference with it.\n",
    "\n",
    "Please make sure the following permission granted before running the notebook:\n",
    "\n",
    "- S3 bucket push access\n",
    "- SageMaker access\n",
    "\n",
    "## Step 1: Let's bump up SageMaker and import stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fa3208",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ac353",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import io\n",
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess._region_name  # region name of the current SageMaker Studio environment\n",
    "account_id = sess.account_id()  # account_id of the current SageMaker Studio environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81deac79",
   "metadata": {},
   "source": [
    "## Step 2: Start preparing model artifacts\n",
    "In LMI contianer, we expect some artifacts to help setting up the model, in this specific example you would replace the s3 bucket locations with the location of your models and compiled model artifacts.\n",
    "- serving.properties (required): Defines the model server settings\n",
    "- model.py (optional): A python file to define the core inference logic\n",
    "- requirements.txt (optional): Any additional pip wheel need to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b011bf5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "engine=Python\n",
    "option.model_id=TheBloke/Llama-2-7B-Chat-fp16\n",
    "option.max_rolling_batch_size=4\n",
    "option.tensor_parallel_degree=12\n",
    "option.n_positions=1280\n",
    "option.rolling_batch=auto\n",
    "option.model_loading_timeout=2400\n",
    "option.output_formatter=jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e37439ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "from djl_python import Input, Output\n",
    "from djl_python.rolling_batch.neuron_rolling_batch import NeuronRollingBatch\n",
    "from djl_python.properties_manager.tnx_properties import TransformerNeuronXProperties\n",
    "from djl_python.transformers_neuronx import TransformersNeuronXService\n",
    "from djl_python.neuron_utils.utils import task_from_config\n",
    "\n",
    "class CustomTNXService(TransformersNeuronXService):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.rolling_batch_config = dict()\n",
    "\n",
    "    def set_configs(self, properties):\n",
    "        self.config = TransformerNeuronXProperties(**properties)\n",
    "        if self.config.rolling_batch != \"disable\":\n",
    "            \"\"\"batch_size needs to match max_rolling_batch_size for precompiled neuron models running rolling batch\"\"\"\n",
    "            self.config.batch_size = self.config.max_rolling_batch_size\n",
    "            if \"output_formatter\" in properties:\n",
    "                self.rolling_batch_config[\"output_formatter\"] = properties.get(\n",
    "                    \"output_formatter\")\n",
    "\n",
    "        self.model_config = AutoConfig.from_pretrained(\n",
    "            self.config.model_id_or_path, revision=self.config.revision)\n",
    "\n",
    "        self.set_model_loader_class()\n",
    "        if not self.config.task:\n",
    "            self.config.task = task_from_config(self.model_config)\n",
    "\n",
    "    def set_rolling_batch(self):\n",
    "        if self.config.rolling_batch != \"disable\":\n",
    "            self.rolling_batch = NeuronRollingBatch(\n",
    "                self.model, self.tokenizer, self.config.batch_size,\n",
    "                self.config.n_positions, **self.rolling_batch_config)\n",
    "\n",
    "\n",
    "_service = CustomTNXService()\n",
    "\n",
    "def handle(inputs: Input):\n",
    "    global _service\n",
    "    if not _service.initialized:\n",
    "        _service.initialize(inputs.get_properties())\n",
    "\n",
    "    if inputs.is_empty():\n",
    "        # Model server makes an empty call to warm up the model on startup\n",
    "        return None\n",
    "\n",
    "    return _service.inference(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd563b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "mv model.py mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e58cf33",
   "metadata": {},
   "source": [
    "## Step 3: Start building SageMaker endpoint\n",
    "In this step, we will build SageMaker endpoint from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d955679",
   "metadata": {},
   "source": [
    "### Getting the container image URI\n",
    "\n",
    "[Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a174b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "        framework=\"djl-neuronx\",\n",
    "        region=sess.boto_session.region_name,\n",
    "        version=\"0.27.0\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11601839",
   "metadata": {},
   "source": [
    "### Upload artifact on S3 and create SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b1e5ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_code_prefix = \"large-model-lmi/code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(\"S3 Code or Model tar ball uploaded\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f39f6",
   "metadata": {},
   "source": [
    "### 4.2 Create SageMaker endpoint\n",
    "\n",
    "You need to specify the instance to use and endpoint names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e0e61cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type = \"ml.inf2.24xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model-demo\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             container_startup_health_check_timeout=2400,\n",
    "             volume_size=256,\n",
    "             endpoint_name=endpoint_name)\n",
    "\n",
    "# our requests and responses will be in json format so we specify the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb63ee65",
   "metadata": {},
   "source": [
    "## Step 5: Test inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79786708",
   "metadata": {},
   "source": [
    "The LineIterator class is used to smooth the output of the token stream."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ccea79",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LineIterator:\n",
    "    \"\"\"\n",
    "    A helper class for parsing the byte stream input. \n",
    "    \n",
    "    The output of the model will be in the following format:\n",
    "    ```\n",
    "    b'{\"outputs\": [\" a\"]}\\n'\n",
    "    b'{\"outputs\": [\" challenging\"]}\\n'\n",
    "    b'{\"outputs\": [\" problem\"]}\\n'\n",
    "    ...\n",
    "    ```\n",
    "    \n",
    "    While usually each PayloadPart event from the event stream will contain a byte array \n",
    "    with a full json, this is not guaranteed and some of the json objects may be split across\n",
    "    PayloadPart events. For example:\n",
    "    ```\n",
    "    {'PayloadPart': {'Bytes': b'{\"outputs\": '}}\n",
    "    {'PayloadPart': {'Bytes': b'[\" problem\"]}\\n'}}\n",
    "    ```\n",
    "    \n",
    "    This class accounts for this by concatenating bytes written via the 'write' function\n",
    "    and then exposing a method which will return lines (ending with a '\\n' character) within\n",
    "    the buffer via the 'scan_lines' function. It maintains the position of the last read \n",
    "    position to ensure that previous bytes are not exposed again. \n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, stream):\n",
    "        self.byte_iterator = iter(stream)\n",
    "        self.buffer = io.BytesIO()\n",
    "        self.read_pos = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        while True:\n",
    "            self.buffer.seek(self.read_pos)\n",
    "            line = self.buffer.readline()\n",
    "            if line and line[-1] == ord('\\n'):\n",
    "                self.read_pos += len(line)\n",
    "                return line[:-1]\n",
    "            try:\n",
    "                chunk = next(self.byte_iterator)\n",
    "            except StopIteration:\n",
    "                if self.read_pos < self.buffer.getbuffer().nbytes:\n",
    "                    continue\n",
    "                raise\n",
    "            if 'PayloadPart' not in chunk:\n",
    "                print('Unknown event type:' + chunk)\n",
    "                continue\n",
    "            self.buffer.seek(0, io.SEEK_END)\n",
    "            self.buffer.write(chunk['PayloadPart']['Bytes'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09f3245f",
   "metadata": {},
   "source": [
    "This is a simple method that we will use for running our line iterator on a prompt, and providing a maximum number of tokens to infer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e178d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_stream_inference(prompt, tokens):\n",
    "    body = {\"inputs\": prompt, \"parameters\": {\"max_new_tokens\": tokens, \"details\":True}}\n",
    "    resp = sess.sagemaker_runtime_client.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=json.dumps(body), ContentType=\"application/json\")\n",
    "    event_stream = resp['Body']\n",
    "\n",
    "    for line in LineIterator(event_stream):\n",
    "        resp = json.loads(line)\n",
    "        if resp.get(\"token\").get(\"text\") is not None:\n",
    "            print(resp.get(\"token\").get(\"text\"), end='')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc09a872",
   "metadata": {},
   "source": [
    "### Text Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e86e5770",
   "metadata": {},
   "outputs": [],
   "source": [
    "first_prompt = \"[INST] Tell me the story of Red Riding Hood in a 4 act play with lots of emojis [/INST]\"\n",
    "\n",
    "chat_stream_inference(first_prompt, 1280)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec770bfd",
   "metadata": {},
   "source": [
    "### Translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab4e34f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "second_prompt = \"\"\"Translate English to French:\n",
    "sea otter => loutre de mer\n",
    "peppermint => menthe poivrée\n",
    "plush girafe => girafe peluche\n",
    "cheese => \"\"\"\n",
    "\n",
    "chat_stream_inference(second_prompt, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c2c9537",
   "metadata": {},
   "source": [
    "### Question Answering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692c02a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "third_prompt = \"Could you remind me when was the C programming language invented?\"\n",
    "\n",
    "chat_stream_inference(third_prompt, 25)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8ce8393",
   "metadata": {},
   "source": [
    "### Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af851e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "fourth_prompt = \"\"\"Tweet: \"I hate it when my phone battery dies.\"\n",
    "Sentiment: Negative\n",
    "###\n",
    "Tweet: \"My day has been :+1:\"\n",
    "Sentiment: Positive\n",
    "###\n",
    "Tweet: \"This is the link to the article\"\n",
    "Sentiment: Neutral\n",
    "###\n",
    "Tweet: \"This new music video was incredibile\"\n",
    "Sentiment:\"\"\"\n",
    "\n",
    "chat_stream_inference(fourth_prompt, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba182e39",
   "metadata": {},
   "source": [
    "### Summarization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60ad4b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "fifth_prompt = \"\"\"Your task is to summarize the following article into five bullet points. The article will be denoted by three backticks.\n",
    "```\n",
    "Amazon SageMaker is a cloud based machine-learning platform that allows the creation, training, and deployment by developers of machine-learning (ML) models on the cloud. It can be used to deploy ML models on embedded systems and edge-devices. SageMaker was launched in November 2017.\n",
    "\n",
    "Capabilities\n",
    "SageMaker enables developers to operate at a number of different levels of abstraction when training and deploying machine learning models. At its highest level of abstraction, SageMaker provides pre-trained ML models that can be deployed as-is. In addition, SageMaker provides a number of built-in ML algorithms that developers can train on their own data. Further, SageMaker provides managed instances of TensorFlow and Apache MXNet, where developers can create their own ML algorithms from scratch. Regardless of which level of abstraction is used, a developer can connect their SageMaker-enabled ML models to other AWS services, such as the Amazon DynamoDB database for structured data storage, AWS Batch for offline batch processing, or Amazon Kinesis for real-time processing.\n",
    "\n",
    "Development interfaces\n",
    "A number of interfaces are available for developers to interact with SageMaker. First, there is a web API that remotely controls a SageMaker server instance. While the web API is agnostic to the programming language used by the developer, Amazon provides SageMaker API bindings for a number of languages, including Python, JavaScript, Ruby, Java, and Go. In addition, SageMaker provides managed Jupyter Notebook instances for interactively programming SageMaker and other applications.\n",
    "```\n",
    "In summary what are the main points of the article above?\"\"\"\n",
    "\n",
    "chat_stream_inference(fifth_prompt, 250)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cd9042",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05c5d883",
   "metadata": {},
   "source": [
    "Uncomment the cell below and run to cleanup endpoint after testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d674b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
